
# fashion_mnist_znkd.py
#
# Full ZNKD-style pipeline for Fashion-MNIST:
#  - Teacher: 10-qubit VQC trained on true labels
#  - ZNE:    zero-noise energies per class (teacher)
#  - Targets: tanh-stabilized regression labels from ZNE-corrected energies
#  - Student: 6-qubit QNN trained via regression on those soft targets
#
# This script is intended as the full experimental pipeline for ZNKD on Fashion-MNIST.

import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from sklearn.decomposition import PCA

from qiskit_aer import AerSimulator
from qiskit.utils import QuantumInstance
from qiskit.circuit.library import EfficientSU2

from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC, NeuralNetworkRegressor

from zne_utils import zne_expectation_zero

# Reproducibility seed shared by PCA, the simulators, and the ZNE helper.
SEED = 123
# Fashion-MNIST has ten clothing categories.
N_CLASSES = 10
# Circuit widths: the teacher is wider than the distilled student.
TEACHER_QUBITS, STUDENT_QUBITS = 10, 6
# One PCA component per teacher qubit (one Ry angle per qubit).
PCA_COMPONENTS = TEACHER_QUBITS


# =========================================================
# 1. Data loading & preprocessing (PCA + scaling to [0, pi] angles)
# =========================================================
def load_fashion_mnist_pca():
    """Load Fashion-MNIST, reduce it with PCA, and rescale for angle encoding.

    Pixels are flattened to 784-d vectors scaled to [0, 1] and projected onto
    ``PCA_COMPONENTS`` principal components. The PCA is fitted on the training
    split only and then applied to the test split. The features are then mapped
    with a single global (not per-feature) affine transform chosen so that the
    *training* features span [0, pi]. The test split uses the same
    transform, so its values may fall slightly outside [0, pi].

    Returns:
        (X_train, X_test, y_train, y_test): PCA feature arrays of shape
        (n_samples, PCA_COMPONENTS) and the label arrays exactly as returned by
        ``fashion_mnist.load_data()``.
    """
    (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

    X_train = X_train.reshape(-1, 28 * 28).astype(np.float32) / 255.0
    X_test = X_test.reshape(-1, 28 * 28).astype(np.float32) / 255.0

    pca = PCA(n_components=PCA_COMPONENTS, random_state=SEED)
    X_train = pca.fit_transform(X_train)
    X_test = pca.transform(X_test)

    # Read the training range once, *before* X_train is overwritten, and reuse
    # it for the test split. Both splits then share one affine map, and no
    # test-set statistics leak into the preprocessing.
    lo = X_train.min()
    span = X_train.max() - lo + 1e-12  # epsilon guards a constant input
    X_train = np.pi * (X_train - lo) / span
    X_test = np.pi * (X_test - lo) / span

    return X_train, X_test, y_train, y_test


def make_feature_map(num_qubits):
    """Return a per-sample Ry angle-encoding circuit builder.

    The returned function maps a feature vector ``x`` to a fresh
    ``num_qubits``-qubit ``QuantumCircuit``. That circuit applies ``Ry(x[i])``
    to qubit ``i`` for each of the first ``num_qubits`` features (fewer if
    ``x`` is shorter). Angles are cast to Python floats, so every produced
    circuit is fully bound. The result is a plain callable, not a
    parameterized circuit.
    """
    from qiskit import QuantumCircuit

    def _encode(sample):
        circuit = QuantumCircuit(num_qubits)
        leading = sample[:num_qubits]
        for qubit, angle in enumerate(leading):
            circuit.ry(float(angle), qubit)
        return circuit

    return _encode


# =========================================================
# 2. Teacher VQC
# =========================================================
def build_teacher_vqc(num_qubits=TEACHER_QUBITS, seed=SEED):
    """Build the (untrained) teacher VQC classifier.

    The circuit is an Ry angle-encoding feature map (qubit ``i`` is rotated by
    input feature ``i``) followed by a trainable ``EfficientSU2(reps=2)``
    ansatz. Circuits run on a seeded ``AerSimulator`` wrapped in a
    ``QuantumInstance``, and the weights are optimized with COBYLA.

    The number of classes is not passed explicitly: ``VQC`` derives it from the
    labels given to ``fit``.

    Args:
        num_qubits: circuit width; also the number of input features consumed.
        seed: seed for the simulator and the transpiler.

    Returns:
        An unfitted ``VQC`` instance.
    """
    # Local imports: all come from qiskit, which the module already depends on
    # (this matches the pre-1.0 era implied by ``qiskit.utils.QuantumInstance``).
    from qiskit import QuantumCircuit
    from qiskit.algorithms.optimizers import COBYLA
    from qiskit.circuit import ParameterVector

    # VQC needs a *parameterized* QuantumCircuit as its feature map, because it
    # reads its width and parameters and composes it with the ansatz. A Python
    # callable that builds a bound circuit per sample cannot be used here.
    inputs = ParameterVector("x", num_qubits)
    feature_map = QuantumCircuit(num_qubits)
    for qubit in range(num_qubits):
        feature_map.ry(inputs[qubit], qubit)

    ansatz = EfficientSU2(num_qubits, reps=2)

    backend = AerSimulator(seed_simulator=seed)
    qinst = QuantumInstance(backend=backend, seed_simulator=seed,
                            seed_transpiler=seed)

    return VQC(
        feature_map=feature_map,
        ansatz=ansatz,
        # Must be an Optimizer instance (or a minimizer callable), not the
        # string name of one.
        optimizer=COBYLA(),
        quantum_instance=qinst,
    )


# =========================================================
# 3. Student QNN (regressor on tanh(ZNE energies))
# =========================================================
def build_student_regressor(num_qubits=STUDENT_QUBITS, seed=SEED, *,
                            num_outputs=N_CLASSES):
    """Build the (untrained) student QNN regressor.

    The student is an ``EfficientSU2(reps=1)`` circuit. Its first
    ``num_qubits`` parameters are bound to the input features and the
    remaining parameters are trainable weights. The 2**num_qubits measurement
    outcomes are folded into ``num_outputs`` bins (outcome ``b`` goes to bin
    ``b % num_outputs``). This gives the network an output of width
    ``num_outputs``, matching (n_samples, num_outputs) regression targets.
    Training minimizes squared error with COBYLA.

    Args:
        num_qubits: circuit width; inputs passed to ``fit``/``predict`` must
            have exactly this many features.
        seed: seed forwarded to the sampler primitive's options.
        num_outputs: width of the regression output (defaults to N_CLASSES).

    Returns:
        An unfitted ``NeuralNetworkRegressor``.
    """
    from qiskit.algorithms.optimizers import COBYLA
    from qiskit.primitives import Sampler

    ansatz = EfficientSU2(num_qubits, reps=1)

    def _fold_outcome(outcome):
        # Map each computational-basis outcome index onto one output bin.
        return outcome % num_outputs

    # NOTE(review): the outputs are bin probabilities in [0, 1] that sum to 1,
    # while the tanh targets lie in [-1, 1]. Negative targets are unreachable,
    # so consider rescaling the targets or using an EstimatorQNN.
    qnn = SamplerQNN(
        circuit=ansatz,
        input_params=ansatz.parameters[:num_qubits],
        weight_params=ansatz.parameters[num_qubits:],
        sparse=False,
        interpret=_fold_outcome,
        output_shape=num_outputs,
        sampler=Sampler(options={"seed": seed}),
    )

    # The regressor takes no quantum_instance; execution is configured on the
    # QNN (via its sampler) instead.
    regressor = NeuralNetworkRegressor(
        neural_network=qnn,
        loss="l2",
        optimizer=COBYLA(),
    )
    return regressor


# =========================================================
# 4. ZNE-based soft targets: tanh-stabilized energies
# =========================================================
def compute_zne_tanh_targets(vqc, X, base_eps=0.01, tau=1.0, *, seed=None):
    """
    Build tanh-stabilized regression targets from ZNE energies.

    For each sample ``x`` in ``X``, ``zne_expectation_zero(vqc, x, ...)`` is
    called to obtain zero-noise energies (presumably one per class; see
    zne_utils). Each energy vector is then squashed into [-1, 1] with
    ``tanh(energies / tau)``. Since tanh is increasing and ``tau > 0``, the
    ordering of energies is preserved, so "lower energy => higher confidence"
    still holds for the targets.

    Args:
        vqc: teacher model, forwarded unchanged to ``zne_expectation_zero``.
        X: non-empty sequence/array of samples.
        base_eps: base noise strength forwarded to ``zne_expectation_zero``.
        tau: positive temperature; larger values soften the targets.
        seed: seed forwarded to ``zne_expectation_zero`` (defaults to SEED).

    Returns:
        targets: np.ndarray stacking one stabilized vector per sample; it has
        shape (n_samples, n_classes) when each energy vector is 1-D.

    Raises:
        ValueError: if ``tau`` is not a positive number or ``X`` is empty.
    """
    # tau == 0 divides by zero; tau < 0 reverses the energy ordering, which
    # breaks the argmin decision rule applied to the student's outputs. The
    # ``not tau > 0`` form also rejects NaN.
    if not tau > 0:
        raise ValueError(f"tau must be a positive number, got {tau!r}")
    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")
    if seed is None:
        seed = SEED

    all_targets = []
    for i, x in enumerate(X):
        energies = zne_expectation_zero(vqc, x, base_eps=base_eps, seed=seed)
        # asarray makes a list/tuple result from the helper divisible by tau.
        stabilized = np.tanh(np.asarray(energies, dtype=float) / tau)
        all_targets.append(stabilized)

        # Report progress only after sample i has actually been processed.
        if (i + 1) % 100 == 0:
            print(f"[ZNE targets] Processed {i+1}/{n_samples} samples", flush=True)

    return np.stack(all_targets, axis=0)


# =========================================================
# 5. Accuracy helpers (teacher vs. student)
# =========================================================
def teacher_accuracy(vqc, X, y):
    """Return the fraction of samples whose teacher prediction equals the label.

    Predictions and labels are both flattened before comparison. An
    (n, 1)-shaped prediction is therefore compared element-wise against (n,)
    labels, instead of silently broadcasting to an (n, n) matrix and
    reporting a meaningless accuracy.

    Raises:
        ValueError: if the number of predicted values differs from the number
            of labels (e.g. one-hot predictions of shape (n, n_classes)).
    """
    y_true = np.ravel(y)
    y_hat = np.ravel(vqc.predict(X))
    if y_hat.shape != y_true.shape:
        raise ValueError(
            f"predict returned {y_hat.size} values for {y_true.size} labels"
        )
    return (y_hat == y_true).mean()


def student_accuracy(student_regressor, X, y):
    """Return the student's classification accuracy on (X, y).

    The student regresses one stabilized energy per class. The predicted class
    is the index of the smallest output (smaller energy => higher confidence).
    Labels are flattened, so (n, 1)-shaped labels are compared element-wise
    rather than broadcast against the (n,) predictions into an (n, n) matrix.

    Raises:
        ValueError: if the number of predictions differs from the number of
            labels.
    """
    preds = np.asarray(student_regressor.predict(X))
    y_hat = np.argmin(preds, axis=1)
    y_true = np.ravel(y)
    if y_hat.shape != y_true.shape:
        raise ValueError(
            f"predict returned {y_hat.size} rows for {y_true.size} labels"
        )
    return (y_hat == y_true).mean()


# =========================================================
# 6. Main ZNKD pipeline
# =========================================================
def main(n_train=None, n_test=None):
    """Run the full ZNKD pipeline and print teacher and student test accuracy.

    Args:
        n_train: if given, train on a random subset of this many training
            samples (drawn without replacement from an RNG seeded with SEED).
        n_test: if given, evaluate on a random subset of this many test
            samples. Both default to None, i.e. the full splits are used.
    """
    print("Loading Fashion-MNIST + PCA …")
    X_train, X_test, y_train, y_test = load_fashion_mnist_pca()

    # Optional subsampling for faster runs. One seeded RNG draws the train
    # indices first and the test indices second, so runs are reproducible.
    rng = np.random.RandomState(SEED)
    if n_train is not None:
        idx_train = rng.choice(len(X_train), n_train, replace=False)
        X_train, y_train = X_train[idx_train], y_train[idx_train]
    if n_test is not None:
        idx_test = rng.choice(len(X_test), n_test, replace=False)
        X_test, y_test = X_test[idx_test], y_test[idx_test]

    print("Building teacher VQC …")
    teacher = build_teacher_vqc()

    print("Training teacher on true labels …")
    teacher.fit(X_train, y_train)

    print("\n=== Teacher accuracy (noiseless simulator) ===")
    acc_teacher = teacher_accuracy(teacher, X_test, y_test)
    print(f"Teacher: {acc_teacher:.3f}")

    print("\nComputing ZNE-based tanh targets for distillation …")
    # base_eps can be tuned; 0.01 is a moderate depolarizing strength
    zne_targets = compute_zne_tanh_targets(teacher, X_train,
                                           base_eps=0.01, tau=1.0)

    # The student circuit has only STUDENT_QUBITS input parameters. It
    # therefore sees the leading STUDENT_QUBITS PCA components (the
    # highest-variance ones), not all the features the teacher uses.
    X_train_student = X_train[:, :STUDENT_QUBITS]
    X_test_student = X_test[:, :STUDENT_QUBITS]

    print("Building student QNN regressor …")
    student = build_student_regressor()

    print("Training student on ZNE-tanh targets …")
    student.fit(X_train_student, zne_targets)

    print("\n=== Distilled student accuracy ===")
    acc_student = student_accuracy(student, X_test_student, y_test)
    print(f"Student (ZNKD regression): {acc_student:.3f}")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
